// This Pine Script™ code is subject to the terms of the Mozilla Public License 2.0 at https://mozilla.org/MPL/2.0/
// © AlgoAlpha

//@version=5
// Oscillator pane (not overlaid on price) built from a k-means-adapted SuperTrend.
indicator(title = "Adaptive SuperTrend Oscillator [AlgoAlpha]", shorttitle = "AlgoAlpha - 🤖 SuperTrend Oscillator", overlay = false, max_labels_count = 500)
import TradingView/ta/7
// ---------------------------------------------------------------- Inputs
// SuperTrend: ATR length used to measure volatility, and the band multiplier.
atr_len = input.int(10, "ATR Length", minval = 1, group = "SuperTrend Settings")
fact = input.float(3, "SuperTrend Factor", group = "SuperTrend Settings")
// K-Means: number of past ATR values clustered on every bar, plus the initial centroid guesses,
// each expressed as a fraction (0..1) of the ATR range over that window.
training_data_period = input.int(70, "Training Data Length", minval = 1, group = "K-Means Settings")
highvol = input.float(0.75, "Initial High volatility Percentile Guess", minval = 0, maxval = 1, group = "K-Means Settings", tooltip = "The initial guess of where the potential 'high volatility' area is, a value of 0.75 will take the 75th percentile of the range of ATR values over the training data period")
midvol = input.float(0.5, "Initial Medium volatility Percentile Guess", minval = 0, maxval = 1, group = "K-Means Settings", tooltip = "The initial guess of where the potential 'medium volatility' area is, a value of 0.5 will take the 50th percentile of the range of ATR values over the training data period")
lowvol = input.float(0.25, "Initial Low volatility Percentile Guess", minval = 0, maxval = 1, group = "K-Means Settings", tooltip = "The initial guess of where the potential 'low volatility' area is, a value of 0.25 will take the 25th percentile of the range of ATR values over the training data period")
// Oscillator post-processing and display.
norm = input.bool(true, "Normalize Oscillator", "If true, the Oscillator will be normalized to be within a fixed range", group = "Oscillator Settings")
smoothed = input.bool(true, "Smooth Oscillator", "If true, the Oscillator will be smoothed with the selected moving average", group = "Oscillator Settings")
HA = input.bool(true, "Display as Heiken Ashi", "If true, the Oscillator will be displayed as Heiken Ashi candles instead of a line", group = "Oscillator Settings")
smthtype = input.string("SMA", "Oscillator Smoothing Method", ["SMA", "EMA"], "What method is used to smooth the Oscillator", group = "Oscillator Settings")
// Lengths feed ta.sma/ta.ema/ta.highest/ta.lowest, which require at least 1 bar.
smthlen = input.int(7, "Smoothing Length", minval = 1, tooltip = "The MA Length for smoothing the Oscillator", group = "Oscillator Settings")
normlen = input.int(300, "Normalization Length", minval = 1, tooltip = "The Normalization Length for Normalizing the Oscillator", group = "Oscillator Settings")
// Lag of 0 would compare the oscillator with itself, so the trend could never be "rising".
trnd = input.int(7, "Trend Lag", minval = 1, tooltip = "The lag used to identify the trend of the Oscillator", group = "Oscillator Settings")
green = input.color(#00ffbb, "Bullish Color", group = "Appearance")
red = input.color(#ff1100, "Bearish Color", group = "Appearance")

// SuperTrend computed from a caller-supplied volatility value instead of a built-in ta.atr()
// series. In this script `atr` is the centroid of the current bar's volatility cluster.
// Appears to follow the reference ta.supertrend() implementation from the Pine docs — confirm.
// @param factor  Multiplier applied to `atr` to offset the bands from hl2.
// @param atr     Volatility value used as the band-width unit.
// @returns [superTrend, _direction, upperBand, lowerBand]. When _direction is -1 the line
//          follows lowerBand; when it is 1 the line follows upperBand.
pine_supertrend(factor, atr) =>
    src = hl2
    upperBand = src + factor * atr
    lowerBand = src - factor * atr
    // Previous bar's final band values (after the ratchet below); 0 when no history exists yet.
    prevLowerBand = nz(lowerBand[1])
    prevUpperBand = nz(upperBand[1])

    // Ratchet the bands. The lower band may only rise and the upper band may only fall, unless
    // the previous close broke through the previous band; in that case the fresh band is used.
    lowerBand := lowerBand > prevLowerBand or close[1] < prevLowerBand ? lowerBand : prevLowerBand
    upperBand := upperBand < prevUpperBand or close[1] > prevUpperBand ? upperBand : prevUpperBand
    int _direction = na
    float superTrend = na
    prevSuperTrend = superTrend[1]
    // Seed the direction with 1 until `atr` has a previous value.
    if na(atr[1])
        _direction := 1
    // The line was on the upper band: switch to -1 only when close breaks above it.
    else if prevSuperTrend == prevUpperBand
        _direction := close > upperBand ? -1 : 1
    // The line was on the lower band: switch to 1 only when close breaks below it.
    else
        _direction := close < lowerBand ? 1 : -1
    superTrend := _direction == -1 ? lowerBand : upperBand
    [superTrend, _direction, upperBand, lowerBand]

// Per-bar ATR, plus the initial centroid guesses for the three volatility clusters. Each guess
// is placed at the chosen fraction of the [lowest, highest] ATR range over the training window.
volatility = ta.atr(atr_len)

upper = ta.highest(volatility, training_data_period)
lower = ta.lowest(volatility, training_data_period)
vol_range = upper - lower

high_volatility = lower + vol_range * highvol
medium_volatility = lower + vol_range * midvol
low_volatility = lower + vol_range * lowvol

// ---------------------------------------------------------------- K-means clustering
// Every bar re-runs k-means (k = 3) over the last `training_data_period` ATR values, starting
// from the initial guesses above. Each iteration prepends the new cluster means to
// amean / bmean / cmean (high / medium / low), so `.first()` is the latest mean and `.get(1)`
// is the previous one. The loop stops once no mean changed. size_a / size_b / size_c hold the
// final sizes of the high / medium / low clusters.
iterations = 0

size_a = 0
size_b = 0
size_c = 0

hv = array.new_float()
mv = array.new_float()
lv = array.new_float()
amean = array.new_float(1,high_volatility)
bmean = array.new_float(1,medium_volatility)
cmean = array.new_float(1,low_volatility)

// True while a cluster's mean history has not settled: either no update has happened yet, or
// the latest mean differs from the previous one. An empty cluster averages to `na`. Two
// consecutive `na` means count as settled, so an empty cluster cannot keep the loop running.
centroid_moved(array<float> means) =>
    bool moved = true
    if means.size() > 1
        float curr = means.get(0)
        float prev = means.get(1)
        moved := not (na(curr) and na(prev)) and curr != prev
    moved

if nz(volatility) > 0 and bar_index >= training_data_period-1

    while centroid_moved(amean) or centroid_moved(bmean) or centroid_moved(cmean)
        hv.clear()
        mv.clear()
        lv.clear()
        for i = training_data_period-1 to 0
            float v = volatility[i]
            float d_a = math.abs(v - amean.first())
            float d_b = math.abs(v - bmean.first())
            float d_c = math.abs(v - cmean.first())
            // Assign each value to exactly one nearest centroid. An `na` distance (empty cluster)
            // never wins. Ties go to the earlier cluster (high, then medium, then low), matching
            // the `indexof(min)` rule used to classify the current bar.
            bool to_a = not na(d_a) and (na(d_b) or d_a <= d_b) and (na(d_c) or d_a <= d_c)
            bool to_b = not to_a and not na(d_b) and (na(d_c) or d_b <= d_c)
            if to_a
                hv.unshift(v)
            else if to_b
                mv.unshift(v)
            else if not na(d_c)
                lv.unshift(v)

        amean.unshift(hv.avg())
        bmean.unshift(mv.avg())
        cmean.unshift(lv.avg())
        size_a := hv.size()
        size_b := mv.size()
        size_c := lv.size()
        iterations := iterations + 1

// Classify the current bar: choose the converged centroid closest to the current ATR.
hv_new = amean.first()
mv_new = bmean.first()
lv_new = cmean.first()
vdist_a = math.abs(volatility - hv_new)
vdist_b = math.abs(volatility - mv_new)
vdist_c = math.abs(volatility - lv_new)

// Both arrays share one index order: 0 = high, 1 = medium, 2 = low.
distances = array.from(vdist_a, vdist_b, vdist_c)
centroids = array.from(hv_new, mv_new, lv_new)

cluster = distances.indexof(distances.min()) // 0 for high, 1 for medium, 2 for low
// indexof() yields -1 when nothing matched; otherwise the index of the winning centroid.
assigned_centroid = cluster >= 0 ? centroids.get(cluster) : na

// SuperTrend whose band width is the centroid of the current bar's volatility cluster.
[ST, dir, oscupper, osclower] = pine_supertrend(fact, assigned_centroid)

// Raw oscillator: distance of close from the SuperTrend line. Optionally it is min-max scaled
// over `normlen` bars and shifted by -0.5.
st_gap = close - ST
gap_lo = ta.lowest(st_gap, normlen)
gap_hi = ta.highest(st_gap, normlen)
raw_osc = norm ? (st_gap - gap_lo) / (gap_hi - gap_lo) - 0.5 : st_gap

// Optional smoothing with the selected moving average.
osc_ema = ta.ema(raw_osc, smthlen)
osc_sma = ta.sma(raw_osc, smthlen)
smthed_osc = not smoothed ? raw_osc : smthtype == "EMA" ? osc_ema : osc_sma

// Heiken Ashi transform of a single series. Each bar's raw OHLC is synthesised from the series
// itself: open = previous value, close = current value, high/low = max/min of the two. These
// are then run through the Heiken Ashi open/close recurrences.
// @param val  Series to transform (called with the smoothed oscillator in this script).
// @returns [haOpen, haHigh, haLow, haClose]
HA(float val) => //HA transform
    o_rsi = val[1]
    h_rsi = math.max(val, val[1])
    l_rsi = math.min(val, val[1])
    c_rsi = val

    haClose = (o_rsi + h_rsi + l_rsi + c_rsi) / 4
    haOpen = float(na)
    // Until a previous HA open exists, seed it from the raw open/close midpoint. After that,
    // average the previous bar's HA open and HA close.
    haOpen := na(haOpen[1]) ? (o_rsi + c_rsi) / 2 : (nz(haOpen[1]) + nz(haClose[1])) / 2
    // HA high/low also cover the HA open/close so the candle body always sits inside its wicks.
    haHigh = math.max(h_rsi, math.max(haOpen, haClose))
    haLow = math.min(l_rsi, math.min(haOpen, haClose))

    [haOpen, haHigh, haLow, haClose]

// Heiken Ashi candles of the smoothed oscillator.
[o, h, l, c] = HA(smthed_osc)

// Rolling 120-bar standard deviation of the oscillator (biased = false); it sets the
// extension bands.
deviations = ta.stdev(source = smthed_osc, length = 120, biased = false)

// Shared colors and levels for the plots below. The plot/fill order itself is unchanged.
osc_rising = smthed_osc > smthed_osc[trnd]
osc_side = smthed_osc > 0 ? green : red
ha_color = c > o ? color.new(green, 12) : color.new(red, 40)
dev2 = deviations * 2
dev3 = deviations * 3

lead = plot(smthed_osc, color = osc_rising ? green : red, title = "Oscillator Plot", display = HA ? display.none : display.all)
lag = plot(smthed_osc[trnd], color = color.yellow, title = "Oscillator Signal", display = display.none)
plotcandle(o, h, l, c, title = "Heiken Ashi Candles", color = ha_color, wickcolor = ha_color, bordercolor = ha_color, display = HA ? display.all : display.none)
mid = plot(0, "Midline", chart.fg_color)
u1 = plot(dev2, display = display.none)
u2 = plot(dev3, "Upper Band Extention", red)
l1 = plot(-dev2, display = display.none)
l2 = plot(-dev3, "Lower Band Extention", green)

fill(lead, lag, osc_rising ? color.new(green, 70) : color.new(red, 70), display = HA ? display.none : display.all)
fill(u1, u2, dev3, dev2, color.new(red, 10), color.new(red, 97))
fill(l1, l2, -dev3, -dev2, color.new(green, 10), color.new(green, 97))
fill(mid, lead, 0, smthed_osc, color.new(osc_side, 80), color.new(osc_side, 97))

// Reversal markers: the oscillator falls back below +2 deviations, or rises back above -2.
bear_rev = ta.crossunder(smthed_osc, dev2)
bull_rev = ta.crossover(smthed_osc, -dev2)
plotchar(bear_rev ? deviations*4 : na, "Bearish Reversal", "▼", location.absolute, red, size = size.tiny)
plotchar(bull_rev ? -deviations*4 : na, "Bullish Reversal", "▲", location.absolute, green, size = size.tiny)

// Summary table, drawn on the last bar only. It shows the centroid and size of each volatility
// cluster, and highlights (column 3) the cluster the current ATR was assigned to.
// Rows: 1 = high (cluster 0), 2 = medium (cluster 1), 3 = low (cluster 2).
if barstate.islast
    var data_table = table.new(position=position.bottom_right, columns=4, rows=4, bgcolor = chart.bg_color, border_width=1, border_color = chart.fg_color, frame_color = chart.fg_color, frame_width = 1)
    // Header row
    table.cell(data_table, text_halign=text.align_center, column=0, row=0, text="Cluster Number (Volatility Level)", text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=1, row=0, text="Cluster Centroid (ATR)", text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=2, row=0, text="Cluster Size", text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=3, row=0, text="Current Volatility", text_color = chart.fg_color, text_size = size.tiny)

    // Row labels
    table.cell(data_table, text_halign=text.align_center, column=0, row=1, text="3 (High)", text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=0, row=2, text= "2 (Medium)", text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=0, row=3, text= "1 (Low)", text_color = chart.fg_color, text_size = size.tiny)

    // Centroids
    table.cell(data_table, text_halign=text.align_center, column=1, row=1, text=str.format("{0,number,#.##}", hv_new), text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=1, row=2, text=str.format("{0,number,#.##}", mv_new), text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=1, row=3, text=str.format("{0,number,#.##}", lv_new), text_color = chart.fg_color, text_size = size.tiny)

    // Cluster sizes: size_a / size_b / size_c belong to the high / medium / low clusters, so
    // each one goes in the row of the matching centroid.
    table.cell(data_table, text_halign=text.align_center, column=2, row=1, text=str.format("{0,number,#.##}", size_a), text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=2, row=2, text=str.format("{0,number,#.##}", size_b), text_color = chart.fg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=2, row=3, text=str.format("{0,number,#.##}", size_c), text_color = chart.fg_color, text_size = size.tiny)

    // Current volatility. The text uses the background color, so only the highlighted row is
    // readable.
    table.cell(data_table, text_halign=text.align_center, column=3, row=1, text="HIGH " + "(ATR: " + str.format("{0,number,#.##}", volatility) + ")", text_color = chart.bg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=3, row=2, text="MEDIUM " + "(ATR: " + str.format("{0,number,#.##}", volatility) + ")", text_color = chart.bg_color, text_size = size.tiny)
    table.cell(data_table, text_halign=text.align_center, column=3, row=3, text="LOW " + "(ATR: " + str.format("{0,number,#.##}", volatility) + ")", text_color = chart.bg_color, text_size = size.tiny)

    // Highlight the row of the cluster the current bar was assigned to.
    data_table.cell_set_bgcolor(3, 1, cluster == 0 ? chart.fg_color : chart.bg_color)
    data_table.cell_set_bgcolor(3, 2, cluster == 1 ? chart.fg_color : chart.bg_color)
    data_table.cell_set_bgcolor(3, 3, cluster == 2 ? chart.fg_color : chart.bg_color)

////////////////////////////Alerts
// Direction follows the plot colors: the oscillator is drawn green when above 0 or rising
// above its `trnd`-bar lagged copy. Crossing up is therefore bullish and crossing down is
// bearish. All alerts fire only on confirmed bars.
alertcondition(ta.crossover(smthed_osc, 0) and barstate.isconfirmed, "Long-Term Bullish Trend Shift")
alertcondition(ta.crossunder(smthed_osc, 0) and barstate.isconfirmed, "Long-Term Bearish Trend Shift")
alertcondition(ta.crossover(smthed_osc, smthed_osc[trnd]) and barstate.isconfirmed, "Short-Term Bullish Trend Shift")
alertcondition(ta.crossunder(smthed_osc, smthed_osc[trnd]) and barstate.isconfirmed, "Short-Term Bearish Trend Shift")
// Reversals: the oscillator falls back below +2 deviations, or rises back above -2.
alertcondition(ta.crossunder(smthed_osc, deviations*2) and barstate.isconfirmed, "Bearish Reversal")
alertcondition(ta.crossover(smthed_osc, -deviations*2) and barstate.isconfirmed, "Bullish Reversal")
// Volatility regime changes: fire on the first bar assigned to a new cluster.
alertcondition(cluster == 0 and cluster[1] != 0 and barstate.isconfirmed, "High Volatility")
alertcondition(cluster == 1 and cluster[1] != 1 and barstate.isconfirmed, "Medium Volatility")
alertcondition(cluster == 2 and cluster[1] != 2 and barstate.isconfirmed, "Low Volatility")